/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

/*!
 * \file reduce_xor_sum_c310_impl.h
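 * \brief ReduceXorSum implementation: XORs two int16_t source tensors element-wise and stores the sum of the results.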
 */
#ifndef IMPL_REDUCE_REDUCE_XOR_SUM_C310_IMPL_H
#define IMPL_REDUCE_REDUCE_XOR_SUM_C310_IMPL_H
#include "kernel_tensor.h"
#include "kernel_operator_intf.h"
#include "kernel_tiling/kernel_tiling.h"

namespace AscendC {
namespace ReduceXorSumAPI {

// Cast trait used by LoadSrcDataAndXor for the T -> float widening cast.
// NOTE(review): despite the "F16" in the name, the only T accepted by ReduceXorSumCheckParams is
// int16_t, so this is effectively an int16 -> float trait. No saturation or rounding mode is requested
// (both are UNKNOWN).
constexpr MicroAPI::CastTrait castTraitF16F32 = {
    MicroAPI::RegLayout::ZERO, MicroAPI::SatMode::UNKNOWN, MicroAPI::MaskMergeMode::ZEROING, RoundMode::UNKNOWN};
// Cast trait used by ReduceXorSumCoreImpl for the float -> T narrowing cast of the final sum.
// Requests SatMode::SAT and RoundMode::CAST_RINT (presumably saturating, round-to-nearest-integer;
// confirm against the MicroAPI::Cast documentation).
constexpr MicroAPI::CastTrait castTraitF32F16 = {
    MicroAPI::RegLayout::ZERO, MicroAPI::SatMode::SAT, MicroAPI::MaskMergeMode::ZEROING, RoundMode::CAST_RINT};

/*!
 * \brief Loads one slice of both sources, XORs them and converts the result to float.
 *
 * The slice starts at element offset index * (GetVecLen() / sizeof(float)) in each source buffer.
 * Both slices are loaded with LoadDist::DIST_UNPACK_B16, XORed under mask, and the XOR result is
 * cast to float into srcReg using castTraitF16F32.
 *
 * \tparam T element type of the source buffers.
 * \param [out] srcReg float register receiving the converted XOR result.
 * \param [in] src0 base address of the first source buffer.
 * \param [in] src1 base address of the second source buffer.
 * \param [in] index slice index; multiplied by the per-slice element count to form the offset.
 * \param [in] mask predicate applied to both the XOR and the cast.
 */
template<typename T>
__simd_callee__ inline void LoadSrcDataAndXor(MicroAPI::RegTensor<float>& srcReg, __ubuf__ T* src0,
    __ubuf__ T* src1, uint16_t index, MicroAPI::MaskReg& mask)
{
    // Elements consumed per call: as many as there are float lanes in one vector register.
    constexpr uint16_t lanesPerSlice = GetVecLen() / sizeof(float);
    MicroAPI::RegTensor<T> lhsReg;
    MicroAPI::RegTensor<T> rhsReg;
    MicroAPI::DataCopy<T, MicroAPI::LoadDist::DIST_UNPACK_B16>(lhsReg, src0 + index * lanesPerSlice);
    MicroAPI::DataCopy<T, MicroAPI::LoadDist::DIST_UNPACK_B16>(rhsReg, src1 + index * lanesPerSlice);
    // The XOR is done in place: lhsReg holds src0 ^ src1 afterwards.
    MicroAPI::Xor(lhsReg, lhsReg, rhsReg, mask);
    MicroAPI::Cast<float, T, castTraitF16F32>(srcReg, lhsReg, mask);
}

/*!
 * \brief Vector-function body of ReduceXorSum: sums src0Ub[i] ^ src1Ub[i] over calCount elements
 *        and stores the result at dstUb.
 *
 * Each loop iteration handles GetVecLen() / sizeof(float) elements. The XORed values are widened to
 * float by LoadSrcDataAndXor, reduced with MicroAPI::ReduceSum and added into a float accumulator.
 * After the loop the accumulator is cast back to T with castTraitF32F16, packed from 32-bit to
 * 16-bit lanes and written with an unaligned store issued with a count of 1.
 *
 * NOTE(review): accumulation happens in float, so very large intermediate sums may not be
 * represented exactly; confirm this is acceptable for the supported calCount range.
 *
 * \tparam T element type of the buffers.
 * \param [out] dstUb destination address that receives the reduced value.
 * \param [in] src0Ub base address of the first source buffer.
 * \param [in] src1Ub base address of the second source buffer.
 * \param [in] calCount number of elements to reduce; taken by value and handed to UpdateMask on
 *             every iteration (presumably decremented by it; confirm against the API).
 */
template<typename T>
__simd_vf__ inline void ReduceXorSumCoreImpl(__ubuf__ T* dstUb, __ubuf__ T* src0Ub,
    __ubuf__ T* src1Ub, uint32_t calCount)
{
    MicroAPI::MaskReg mask;
    MicroAPI::UnalignReg ureg;
    MicroAPI::RegTensor<T> dstReg;
    // sumReg is the running float accumulator across all iterations.
    MicroAPI::RegTensor<float> srcTmpReg, dstTmpReg, sumReg;
    MicroAPI::MaskReg fullMask = MicroAPI::CreateMask<float>();
    constexpr int32_t eleCountPerVL = GetVecLen() / sizeof(float);
    uint16_t repeatTimes = CeilDivision(calCount, eleCountPerVL);
    // Start the accumulator at zero so that calCount == 0 yields a zero result.
    MicroAPI::Duplicate(sumReg, 0, fullMask);
    for (uint16_t i = 0; i < repeatTimes; i++) {
        // Predicate for this iteration, derived from the remaining element count.
        mask = MicroAPI::UpdateMask<float>(calCount);
        LoadSrcDataAndXor(srcTmpReg, src0Ub, src1Ub, i, mask);
        // Reduce the current slice, then fold the partial sum into the accumulator.
        MicroAPI::ReduceSum(dstTmpReg, srcTmpReg, mask);
        MicroAPI::Add(sumReg, sumReg, dstTmpReg, mask);
    }
    MicroAPI::Cast<T, float, castTraitF32F16>(dstReg, sumReg, fullMask);
    // Narrow in place from 32-bit lanes to 16-bit lanes (LOWEST part) before storing.
    MicroAPI::Pack<uint16_t, uint32_t, MicroAPI::HighLowPart::LOWEST>(
        (MicroAPI::RegTensor<uint16_t>&)dstReg, (MicroAPI::RegTensor<uint32_t>&)dstReg);
    // Unaligned store followed by its Post call (presumably flushes the pending data held in ureg).
    MicroAPI::DataCopyUnAlign(dstUb, dstReg, ureg, 1);
    MicroAPI::DataCopyUnAlignPost(dstUb, ureg, 0);
}
} // namespace ReduceXorSumAPI

/*!
 * \brief Validates the arguments of ReduceXorSum before the computation is issued.
 *
 * Rejects at compile time any element type other than int16_t, then runs CheckTensorPos on all four
 * tensors against Hardware::UB and CheckCalCount for calCount against both source tensors.
 *
 * \tparam T element type of the destination and source tensors; must be int16_t.
 * \tparam isReuseSource not referenced inside this function.
 * \param [in] dstTensor destination tensor; only its position is checked.
 * \param [in] src0Tensor first source tensor; position and calCount are checked.
 * \param [in] src1Tensor second source tensor; position and calCount are checked.
 * \param [in] sharedTmpBuffer temporary buffer; only its position is checked.
 * \param [in] calCount element count checked against both source tensors.
 */
template <typename T, bool isReuseSource = false>
__aicore__ inline void ReduceXorSumCheckParams(LocalTensor<T>& dstTensor, const LocalTensor<T>& src0Tensor,
    const LocalTensor<T>& src1Tensor, LocalTensor<uint8_t>& sharedTmpBuffer, const uint32_t calCount)
{
    // Compile-time guard: int16_t is the only supported element type.
    static_assert(std::is_same<T, int16_t>::value, "ReduceXorSum only support int16_t data type on current device!");
    // Position checks: every tensor is checked against Hardware::UB.
    CheckTensorPos<T>(dstTensor, Hardware::UB, "dstTensor", "VECIN / VECCALC / VECOUT", "ReduceXorSum");
    CheckTensorPos<T>(src0Tensor, Hardware::UB, "src0Tensor", "VECIN / VECCALC / VECOUT", "ReduceXorSum");
    CheckTensorPos<T>(src1Tensor, Hardware::UB, "src1Tensor", "VECIN / VECCALC / VECOUT", "ReduceXorSum");
    CheckTensorPos<uint8_t>(sharedTmpBuffer,
        Hardware::UB, "sharedTmpBuffer", "VECIN / VECCALC / VECOUT", "ReduceXorSum");
    // Count checks against each source (presumably calCount must not exceed the tensor size; confirm
    // against CheckCalCount). The destination size is not checked here.
    CheckCalCount(calCount, "calCount", src0Tensor, "src0Tensor", "ReduceXorSum");
    CheckCalCount(calCount, "calCount", src1Tensor, "src1Tensor", "ReduceXorSum");
}

/*!
 * \brief Entry point of the ReduceXorSum implementation: validates the arguments and launches the
 *        vector-function body on the tensors' physical addresses.
 *
 * \tparam T element type of the destination and source tensors.
 * \tparam isReuseSource forwarded to ReduceXorSumCheckParams; not otherwise used here.
 * \param [out] dstTensor tensor whose address receives the reduced value.
 * \param [in] src0Tensor first source tensor.
 * \param [in] src1Tensor second source tensor.
 * \param [in] sharedTmpBuffer passed to the parameter check only; the computation does not use it.
 * \param [in] calCount number of elements to process.
 */
template <typename T, bool isReuseSource = false>
__aicore__ inline void ReduceXorSumCompute(LocalTensor<T>& dstTensor, const LocalTensor<T>& src0Tensor,
    const LocalTensor<T>& src1Tensor, LocalTensor<uint8_t>& sharedTmpBuffer, const uint32_t calCount)
{
    // This API targets the AI Vector Core only: do nothing when ASCEND_IS_AIC holds.
    if ASCEND_IS_AIC {
        return;
    }

    ReduceXorSumCheckParams<T, isReuseSource>(dstTensor, src0Tensor, src1Tensor, sharedTmpBuffer, calCount);

    // Resolve the physical addresses once and hand them to the vector function.
    __local_mem__ T *outAddr = (__local_mem__ T *)dstTensor.GetPhyAddr();
    __local_mem__ T *lhsAddr = (__local_mem__ T *)src0Tensor.GetPhyAddr();
    __local_mem__ T *rhsAddr = (__local_mem__ T *)src1Tensor.GetPhyAddr();
    ReduceXorSumAPI::ReduceXorSumCoreImpl<T>(outAddr, lhsAddr, rhsAddr, calCount);
}
} // namespace AscendC

#endif // IMPL_REDUCE_REDUCE_XOR_SUM_C310_IMPL_H